import torch
from torch import nn

"""
残差与层归一化结构
残差：梯度消失问题
归一化：梯度爆炸问题
"""


class AddNorm(nn.Module):
    def __init__(self, hidden_dims, dropout):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.norm = nn.LayerNorm(hidden_dims)

    def forward(self, x, sublayer_out):
        # Standard post-LN Add & Norm: LayerNorm(x + Dropout(sublayer_out)).
        # Dropout is applied to the sublayer output before the residual add.
        return self.norm(x + self.dropout(sublayer_out))
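

# Minimal usage sketch (an assumption for illustration, not part of the original
# module): feed a batch of token embeddings together with the output of some
# sublayer, e.g. self-attention, whose shape matches the residual input.
# The shapes and values below are illustrative only.
if __name__ == "__main__":
    batch, seq_len, hidden_dims = 2, 4, 8
    add_norm = AddNorm(hidden_dims, dropout=0.1)
    x = torch.randn(batch, seq_len, hidden_dims)              # residual input
    sublayer_out = torch.randn(batch, seq_len, hidden_dims)   # e.g. attention output
    out = add_norm(x, sublayer_out)
    print(out.shape)  # torch.Size([2, 4, 8])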
